# Copyright (c) 2024 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import sys

from openmind.flow.arguments import get_args, initialize_openmind
from openmind.flow.train import run_sft, run_pt, run_dpo, run_rm
from openmind.flow.callbacks import get_swanlab_callbacks
from openmind.utils.constants import Stages


def run_train(**kwargs):
    """Initialize openMind from the command line and run the configured training stage.

    If exactly one command-line argument is given and it ends with ``yaml``,
    it is passed to ``initialize_openmind`` as the YAML config path; otherwise
    ``None`` is passed as the YAML path.

    After initialization, the stage stored on the parsed arguments selects
    which trainer runs: ``run_sft`` / ``run_pt`` receive the swanlab
    callbacks, while ``run_dpo`` / ``run_rm`` are called without arguments.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``initialize_openmind``.

    Raises:
        ValueError: If ``args.stage`` is not one of ``Stages.SFT``,
            ``Stages.PT``, ``Stages.DPO`` or ``Stages.RM``, so an unsupported
            stage fails loudly instead of silently doing nothing.
    """
    # initialize params: a lone "*yaml" argument is treated as the config file.
    command_line = sys.argv[1:]
    if len(command_line) == 1 and command_line[0].endswith("yaml"):
        yaml_file = command_line[0]
    else:
        yaml_file = None
    initialize_openmind(yaml_file, **kwargs)
    args = get_args()

    callbacks = get_swanlab_callbacks()

    if args.stage == Stages.SFT:
        run_sft(callbacks)
    elif args.stage == Stages.PT:
        run_pt(callbacks)
    elif args.stage == Stages.DPO:
        # NOTE(review): DPO/RM are invoked without the swanlab callbacks that
        # SFT/PT receive — confirm this is intended.
        run_dpo()
    elif args.stage == Stages.RM:
        run_rm()
    else:
        # Without this branch an unknown stage would return silently and the
        # "training run" would do nothing.
        raise ValueError(
            f"Unsupported training stage {args.stage!r}; expected one of "
            f"{Stages.SFT!r}, {Stages.PT!r}, {Stages.DPO!r}, {Stages.RM!r}."
        )


if __name__ == "__main__":
    # Script entry point: run_train() reads its arguments (e.g. a single YAML
    # config path) from sys.argv, so no arguments are passed here.
    run_train()
